iT邦幫忙

2024 iThome 鐵人賽

DAY 25
0
自我挑戰組

菜鳥AI工程師給碩班學弟妹的挑戰系列 第 25

[Day 25] fastapi + lightning -> LitServe - 3

  • 分享至 

  • xImage
  •  

前情提要: 昨天有利用LitServe將bert model架起來使用,以及測試檔案上傳的部分。

補上 bert 參考: https://leemeng.tw/attack_on_bert_transfer_learning_in_nlp.html

官方範例: https://lightning.ai/lightning-ai/studios/deploy-a-noise-cancellation-api-with-deepfilternet

我昨天發完文章後繼續研究,將每個範例大致看過,在上面的範例為語音增強,上傳檔案然後inference最後回傳檔案,可以發現在上傳檔案與昨天我自己寫的一樣,看來這部分是沒有問題。

接下來試試看用我們之前的model來部屬看看。

1. Server.py

在之前的程式,基本上我們是將load image等等放在dataset當中,此時我們可以簡化,直接放在setup,不過如果dataset比較複雜,這部分就不知怎麼處理,一樣是採用dataloader嗎? 這部分可能要等官方有更多範例或之後有做這塊。

from fastapi import Request, Response
from litserve import LitAPI, LitServer
import torchvision.transforms as transforms
from infer import example
from PIL import Image
import torch

class SimpleFileLitAPI(LitAPI):
    def setup(self, device):
        self.transform = transforms.Compose([
            transforms.Resize((28, 28)),  # 確保圖片大小一致
            transforms.ToTensor(),        # 轉換為PyTorch張量
            transforms.Normalize((0.5, ), (0.5, ))  # 標準化
        ])
        self.path = '0_upload.jpg'

        self.model = example()
        ckpt = torch.load('last.ckpt', map_location = torch.device(device))["state_dict"]
        self.model.load_state_dict(ckpt, strict = False)
        self.model.eval()

    def decode_request(self, request: Request):
        with open(self.path, "wb+") as file_object:
            file_object.write(request["file"].file.read())
        return 1 

    def predict(self, x):
        image = Image.open(self.path).convert('L')  # MNIST是灰度圖,轉換為'L'模式
        image = self.transform(image)
        return self.model(image)

    def encode_response(self, output) -> Response:
        return {"output": output.item()}


if __name__ == "__main__":
    server = LitServer(SimpleFileLitAPI(), accelerator="cpu", workers_per_device = 1)
    server.run(port = 8000, num_api_servers = 1)

2. infer.py

將infer.py做些修改,將原先predict_step的東西放到forward,為了讓server.py的predict,可以直接使用self.model去做

import lightning as pl
from model import MNISTClassifier

class example(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = MNISTClassifier()

    def forward(self, batch):
        pred = self.model(batch).argmax(dim = -1)
        return pred.cpu().detach()

接下來就可以將server開起來,然後用昨天的client.py程式,去上傳0.jpg去測試
https://ithelp.ithome.com.tw/upload/images/20240829/20168446OqGuvzs1Xu.png

看起來跑自己的model是沒有問題的。

今天就先到這裡囉~


上一篇
[Day 24] fastapi + lightning -> LitServe - 2
下一篇
[Day 26] LitServe總結 + umap圖示化bert
系列文
菜鳥AI工程師給碩班學弟妹的挑戰30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言